{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 写入 tfrecord 文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "100%|██████████| 30/30 [00:00<00:00, 8203.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "There are 15 sample writen into writer1\n",
      "Finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# -*- coding:utf-8 -*- \n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "'''tfrecord 写入序列数据，每个样本的长度不固定。\n",
    "主要参考：\n",
    "https://github.com/tensorflow/magenta/blob/master/magenta/common/sequence_example_lib.py\n",
    "其他参考: \n",
    "https://github.com/dennybritz/tf-rnn\n",
    "http://leix.me/2017/01/09/tensorflow-practical-guides/\n",
    "https://github.com/siavash9000/im2txt_demo/blob/master/im2txt/im2txt/ops/inputs.py\n",
    "'''\n",
    "\n",
    "# **1.创建文件\n",
    "writer1 = tf.python_io.TFRecordWriter('../data/seq_test1.tfrecord')\n",
    "writer2 = tf.python_io.TFRecordWriter('../data/seq_test2.tfrecord')\n",
    "\n",
    "# 非序列数据\n",
    "labels = [1, 2, 3, 4, 5,\n",
    "          1, 2, 3, 4, 5, \n",
    "          1, 2, 3, 4, 5, \n",
    "          1, 2, 3, 4, 5, \n",
    "          1, 2, 3, 4, 5, \n",
    "          1, 2, 3, 4, 5]\n",
    "# 长度不固定的序列\n",
    "frames = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],\n",
    "          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],\n",
    "          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],\n",
    "          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],\n",
    "          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],\n",
    "          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]]\n",
    "\n",
    "writer = writer1\n",
    "for i in tqdm(range(len(labels))):  # **2.对于每个样本\n",
    "    if i == len(labels) / 2:\n",
    "        writer = writer2\n",
    "        print('\\nThere are %d sample writen into writer1' % i)\n",
    "    label = labels[i]\n",
    "    frame = frames[i]\n",
    "    # 非序列化\n",
    "    label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))\n",
    "    # 序列化\n",
    "    frame_feature = [\n",
    "        tf.train.Feature(int64_list=tf.train.Int64List(value=[frame_])) for frame_ in frame\n",
    "    ]\n",
    "    \n",
    "    seq_example = tf.train.SequenceExample(\n",
    "        # context 来放置非序列化部分\n",
    "        context=tf.train.Features(feature={\n",
    "            \"label\": label_feature\n",
    "        }),\n",
    "        # feature_lists 放置变长序列\n",
    "        feature_lists=tf.train.FeatureLists(feature_list={\n",
    "            \"frame\": tf.train.FeatureList(feature=frame_feature),\n",
    "        })\n",
    "    )\n",
    "    serialized = seq_example.SerializeToString()\n",
    "    writer.write(serialized)  # **4.写入文件中\n",
    "\n",
    "print('Finished.')\n",
    "writer1.close()\n",
    "writer2.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取batch\n",
    "先 restart kernel， 再执行下面命令。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"ParseSingleSequenceExample/ParseSingleSequenceExample:0\", shape=(), dtype=int64)\n",
      "Tensor(\"ParseSingleSequenceExample/ParseSingleSequenceExample:1\", shape=(?,), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "# -*- coding:utf-8 -*- \n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "'''read seq_data'''\n",
    "\n",
    "# output file name string to a queue\n",
    "tfrecord_file_names = ['../data/seq_test1.tfrecord', '../data/seq_test2.tfrecord']\n",
    "\n",
    "filename_queue = tf.train.string_input_producer(tfrecord_file_names, shuffle=True, capacity=2)\n",
    "reader = tf.TFRecordReader()\n",
    "_, serialized_example = reader.read(filename_queue)\n",
    "context_features = {\n",
    "    \"label\": tf.FixedLenFeature([], dtype=tf.int64)\n",
    "}\n",
    "sequence_features = {\n",
    "    \"frame\": tf.FixedLenSequenceFeature([], dtype=tf.int64)\n",
    "}\n",
    "context_parsed, sequence_parsed = tf.parse_single_sequence_example(\n",
    "    serialized=serialized_example,\n",
    "    context_features=context_features,\n",
    "    sequence_features=sequence_features\n",
    ")\n",
    "\n",
    "label = context_parsed['label']\n",
    "frame = sequence_parsed['frame']\n",
    "\n",
    "print(label)\n",
    "print(frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"batch:1\", shape=(10, ?), dtype=int64) Tensor(\"batch:0\", shape=(10,), dtype=int64)\n",
      "** batch 0\n",
      "[1 2 3 4 1 2 3 5 4 5]\n",
      "[[1 0 0 0 0]\n",
      " [2 2 0 0 0]\n",
      " [3 3 3 0 0]\n",
      " [4 4 4 4 0]\n",
      " [1 0 0 0 0]\n",
      " [2 2 0 0 0]\n",
      " [3 3 3 0 0]\n",
      " [5 5 5 5 5]\n",
      " [4 4 4 4 0]\n",
      " [5 5 5 5 5]]\n",
      "** batch 1\n",
      "[2 1 3 4 5 1 2 3 4 5]\n",
      "[[2 2 0 0 0]\n",
      " [1 0 0 0 0]\n",
      " [3 3 3 0 0]\n",
      " [4 4 4 4 0]\n",
      " [5 5 5 5 5]\n",
      " [1 0 0 0 0]\n",
      " [2 2 0 0 0]\n",
      " [3 3 3 0 0]\n",
      " [4 4 4 4 0]\n",
      " [5 5 5 5 5]]\n"
     ]
    }
   ],
   "source": [
    "label_batch, frame_batch = tf.train.batch(\n",
    "    [label, frame],\n",
    "    batch_size=10,\n",
    "    num_threads=4,\n",
    "    capacity=500,\n",
    "    dynamic_pad=True,\n",
    "    allow_smaller_final_batch=False\n",
    ")\n",
    "\n",
    "print(frame_batch, label_batch)\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "tf.train.start_queue_runners(sess=sess)\n",
    "for i in range(2):\n",
    "    _frames_batch, _label_batch = sess.run([frame_batch, label_batch])\n",
    "    print('** batch %d' % i)\n",
    "    print(_label_batch)\n",
    "    print(_frames_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### shuffle 读取batch\n",
    "**restart kenerl** 再继续"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "# -*- coding:utf-8 -*-\n",
    "\n",
    "import tensorflow as tf\n",
    "import math\n",
    "\n",
    "QUEUE_CAPACITY = 100\n",
    "SHUFFLE_MIN_AFTER_DEQUEUE = QUEUE_CAPACITY // 5\n",
    "\n",
    "\n",
    "def _shuffle_inputs(input_tensors, capacity, min_after_dequeue, num_threads):\n",
    "    \"\"\"Shuffles tensors in `input_tensors`, maintaining grouping.\"\"\"\n",
    "    shuffle_queue = tf.RandomShuffleQueue(\n",
    "        capacity, min_after_dequeue, dtypes=[t.dtype for t in input_tensors])\n",
    "    enqueue_op = shuffle_queue.enqueue(input_tensors)\n",
    "    runner = tf.train.QueueRunner(shuffle_queue, [enqueue_op] * num_threads)\n",
    "    tf.train.add_queue_runner(runner)\n",
    "\n",
    "    output_tensors = shuffle_queue.dequeue()\n",
    "\n",
    "    for i in range(len(input_tensors)):\n",
    "        output_tensors[i].set_shape(input_tensors[i].shape)\n",
    "\n",
    "    return output_tensors\n",
    "\n",
    "\n",
    "def get_padded_batch(file_list, batch_size, num_enqueuing_threads=4, shuffle=False):\n",
    "    \"\"\"Reads batches of SequenceExamples from TFRecords and pads them.\n",
    "\n",
    "    Can deal with variable length SequenceExamples by padding each batch to the\n",
    "    length of the longest sequence with zeros.\n",
    "\n",
    "    Args:\n",
    "      file_list: A list of paths to TFRecord files containing SequenceExamples.\n",
    "      batch_size: The number of SequenceExamples to include in each batch.\n",
    "      num_enqueuing_threads: The number of threads to use for enqueuing\n",
    "          SequenceExamples.\n",
    "      shuffle: Whether to shuffle the batches.\n",
    "\n",
    "    Returns:\n",
    "      labels: A tensor of shape [batch_size] of int64s.\n",
    "      frames: A tensor of shape [batch_size, num_steps] of floats32s. note that\n",
    "          num_steps is the max time_step of all the tensors.\n",
    "    Raises:\n",
    "      ValueError: If `shuffle` is True and `num_enqueuing_threads` is less than 2.\n",
    "    \"\"\"\n",
    "    file_queue = tf.train.string_input_producer(file_list)\n",
    "    reader = tf.TFRecordReader()\n",
    "    _, serialized_example = reader.read(file_queue)\n",
    "\n",
    "    context_features = {\n",
    "        \"label\": tf.FixedLenFeature([], dtype=tf.int64)\n",
    "    }\n",
    "    sequence_features = {\n",
    "        \"frame\": tf.FixedLenSequenceFeature([], dtype=tf.int64)\n",
    "    }\n",
    "\n",
    "    context_parsed, sequence_parsed = tf.parse_single_sequence_example(\n",
    "        serialized=serialized_example,\n",
    "        context_features=context_features,\n",
    "        sequence_features=sequence_features\n",
    "    )\n",
    "\n",
    "    labels = context_parsed['label']\n",
    "    frames = sequence_parsed['frame']\n",
    "    input_tensors = [labels, frames]\n",
    "\n",
    "    if shuffle:\n",
    "        if num_enqueuing_threads < 2:\n",
    "            raise ValueError(\n",
    "                '`num_enqueuing_threads` must be at least 2 when shuffling.')\n",
    "        shuffle_threads = int(math.ceil(num_enqueuing_threads) / 2.)\n",
    "\n",
    "        # Since there may be fewer records than SHUFFLE_MIN_AFTER_DEQUEUE, take the\n",
    "        # minimum of that number and the number of records.\n",
    "        min_after_dequeue = count_records(\n",
    "            file_list, stop_at=SHUFFLE_MIN_AFTER_DEQUEUE)\n",
    "        input_tensors = _shuffle_inputs(\n",
    "            input_tensors, capacity=QUEUE_CAPACITY,\n",
    "            min_after_dequeue=min_after_dequeue,\n",
    "            num_threads=shuffle_threads)\n",
    "\n",
    "        num_enqueuing_threads -= shuffle_threads\n",
    "\n",
    "    tf.logging.info(input_tensors)\n",
    "    return tf.train.batch(\n",
    "        input_tensors,\n",
    "        batch_size=batch_size,\n",
    "        capacity=QUEUE_CAPACITY,\n",
    "        num_threads=num_enqueuing_threads,\n",
    "        dynamic_pad=True,\n",
    "        allow_smaller_final_batch=False)\n",
    "\n",
    "\n",
    "def count_records(file_list, stop_at=None):\n",
    "    \"\"\"Counts number of records in files from `file_list` up to `stop_at`.\n",
    "\n",
    "    Args:\n",
    "      file_list: List of TFRecord files to count records in.\n",
    "      stop_at: Optional number of records to stop counting at.\n",
    "\n",
    "    Returns:\n",
    "      Integer number of records in files from `file_list` up to `stop_at`.\n",
    "    \"\"\"\n",
    "    num_records = 0\n",
    "    for tfrecord_file in file_list:\n",
    "        tf.logging.info('Counting records in %s.', tfrecord_file)\n",
    "        for _ in tf.python_io.tf_record_iterator(tfrecord_file):\n",
    "            num_records += 1\n",
    "            if stop_at and num_records >= stop_at:\n",
    "                tf.logging.info('Number of records is at least %d.', num_records)\n",
    "                return num_records\n",
    "    tf.logging.info('Total records: %d', num_records)\n",
    "    return num_records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Counting records in ../data/seq_test1.tfrecord.\n",
      "INFO:tensorflow:Counting records in ../data/seq_test2.tfrecord.\n",
      "INFO:tensorflow:Number of records is at least 20.\n",
      "INFO:tensorflow:[<tf.Tensor 'random_shuffle_queue_Dequeue:0' shape=() dtype=int64>, <tf.Tensor 'random_shuffle_queue_Dequeue:1' shape=(?,) dtype=int64>]\n",
      "** batch 0\n",
      "[1 2 3 3 5 1 5 2 4 3]\n",
      "[[1 0 0 0 0]\n",
      " [2 2 0 0 0]\n",
      " [3 3 3 0 0]\n",
      " [3 3 3 0 0]\n",
      " [5 5 5 5 5]\n",
      " [1 0 0 0 0]\n",
      " [5 5 5 5 5]\n",
      " [2 2 0 0 0]\n",
      " [4 4 4 4 0]\n",
      " [3 3 3 0 0]]\n",
      "** batch 1\n",
      "[1 2 3 1 1 4 4 2 1 4]\n",
      "[[1 0 0 0]\n",
      " [2 2 0 0]\n",
      " [3 3 3 0]\n",
      " [1 0 0 0]\n",
      " [1 0 0 0]\n",
      " [4 4 4 4]\n",
      " [4 4 4 4]\n",
      " [2 2 0 0]\n",
      " [1 0 0 0]\n",
      " [4 4 4 4]]\n",
      "** batch 2\n",
      "[2 3 4 3 2 4 1 5 5 5]\n",
      "[[2 2 0 0 0]\n",
      " [3 3 3 0 0]\n",
      " [4 4 4 4 0]\n",
      " [3 3 3 0 0]\n",
      " [2 2 0 0 0]\n",
      " [4 4 4 4 0]\n",
      " [1 0 0 0 0]\n",
      " [5 5 5 5 5]\n",
      " [5 5 5 5 5]\n",
      " [5 5 5 5 5]]\n"
     ]
    }
   ],
   "source": [
    "tfrecord_file_names = ['../data/seq_test1.tfrecord', '../data/seq_test2.tfrecord']\n",
    "label_batch, frame_batch = get_padded_batch(tfrecord_file_names, 10, shuffle=True)\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "tf.train.start_queue_runners(sess=sess)\n",
    "for i in range(3):\n",
    "    _frames_batch, _label_batch = sess.run([frame_batch, label_batch])\n",
    "    print('** batch %d' % i)\n",
    "    print(_label_batch)\n",
    "    print(_frames_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 统计 tfrecord 文件中样本个数函数测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15\n"
     ]
    }
   ],
   "source": [
    "tfrecord_file = '../data/seq_test1.tfrecord'\n",
    "num = 0\n",
    "for _ in tf.python_io.tf_record_iterator(tfrecord_file):\n",
    "    num += 1\n",
    "print(num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15\n"
     ]
    }
   ],
   "source": [
    "tfrecord_file = '../data/seq_test2.tfrecord'\n",
    "num = 0\n",
    "for _ in tf.python_io.tf_record_iterator(tfrecord_file):\n",
    "    num += 1\n",
    "print(num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
